
# torch and torchvision
import torch

# custom libs
from networks.mnist.models_mnist import LeNet_mnist, Base
from networks.cifar10.models_cifar10 import ResNet18_cifar10, BasicBlock

# ------------------------------------------------------------------------------
#   Networks
# ------------------------------------------------------------------------------
def load_network(dataset, netname, intname='output'):
    """Construct an untrained network for a dataset/architecture pair.

    Args:
        dataset: dataset name; ``'mnist'`` or ``'cifar10'``.
        netname: architecture name. ``'lenet'`` or ``'base'`` for MNIST,
            ``'resnet18'`` for CIFAR-10.
        intname: not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A freshly constructed model instance.

    Raises:
        AssertionError: if ``dataset`` or ``netname`` is not recognised.
            Raised explicitly rather than via an ``assert`` statement so the
            check is not stripped when Python runs with ``-O`` (which would
            otherwise make this function silently return ``None``).
            NOTE(review): ``ValueError`` would be more idiomatic; the
            AssertionError type is kept for callers that may catch it.
    """
    # base networks for the MNIST dataset
    if 'mnist' == dataset:
        if 'lenet' == netname:
            return LeNet_mnist(num_classes=10)
        if 'base' == netname:
            return Base()
        raise AssertionError('Error: invalid network name [{}]'.format(netname))

    # base networks for the CIFAR-10 dataset
    if 'cifar10' == dataset:
        if 'resnet18' == netname:
            # presumably the per-stage BasicBlock counts of ResNet-18;
            # see ResNet18_cifar10 in networks.cifar10.models_cifar10
            return ResNet18_cifar10(BasicBlock, [2, 2, 2, 2])
        raise AssertionError('Error: invalid network name [{}]'.format(netname))

    # TODO - define more networks per dataset here.

    # Undefined dataset
    raise AssertionError('Error: invalid dataset name [{}]'.format(dataset))

def load_trained_network(net, cuda, fpath):
    """Load weights saved at ``fpath`` into ``net`` (modified in place).

    Args:
        net: the ``torch.nn.Module`` to populate. ``load_state_dict`` is
            called with its default ``strict=True``, so the checkpoint keys
            must match the module's parameters/buffers exactly.
        cuda: if truthy, tensors are restored to the device they were saved
            from (``torch.load`` default). Otherwise every tensor is mapped
            onto the CPU, so GPU-saved checkpoints load on CPU-only machines.
        fpath: path to a file written by ``torch.save`` that deserializes to
            a mapping accepted by ``net.load_state_dict``.

    Returns:
        ``net`` itself, for call-chaining convenience (previously ``None``).
    """
    # ``map_location=None`` is torch.load's default (keep saved devices);
    # ``'cpu'`` is the documented equivalent of
    # ``lambda storage, loc: storage``.
    map_location = None if cuda else 'cpu'
    model_dict = torch.load(fpath, map_location=map_location)
    net.load_state_dict(model_dict)
    return net
    # done.
